import os
import argparse
import datetime
import hashlib
import shutil
import numpy as np

import torch
import torch.nn as nn
import cls.dataset.luna
import cls.net
import cls.aug

from matplotlib import pyplot
from torch.utils.tensorboard import SummaryWriter
from torch.optim import SGD
from torch.utils.data import DataLoader
from config import DATA_PATH
from utils.enumerate import enumerateWithEstimate
from utils.logconf import logging

log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)

# Used for computeBatchLoss and logMetrics to index into metrics_t/metrics_a
METRICS_LABEL_NDX = 0
METRICS_PRED_NDX = 1
METRICS_PRED_P_NDX = 2
METRICS_LOSS_NDX = 3
METRICS_SIZE = 4
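# One column per sample in the epoch, indexed by the constants above:
#   metrics_t[METRICS_LABEL_NDX,  i]  true class index of sample i
#   metrics_t[METRICS_PRED_NDX,   i]  predicted class index
#   metrics_t[METRICS_PRED_P_NDX, i]  predicted probability of the positive class
#   metrics_t[METRICS_LOSS_NDX,   i]  per-sample cross-entropy loss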


class ClassificationTrainingApp:
    def __init__(self, cli_args):
        self.cli_args = cli_args
        self.time_str = datetime.datetime.now().strftime("%Y-%m-%d_%H.%M.%S")

        self.trn_writer = None
        self.val_writer = None
        self.totalTrainingSamples_count = 0

        # Augmentation is always fully enabled for training. The exact meaning
        # of each value is defined by cls.aug.augment3d; the assumption here is
        # that "offset" and "scale" are maximum random fractions of the crop
        # and "noise" is the scale of the additive random noise.
        self.augmentation_dict = {
            "flip": True,
            "offset": 0.1,
            "scale": 0.2,
            "rotate": True,
            "noise": 25.0,
        }

        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")

        self.model = self.initModel()
        self.optimizer = self.initOptimizer()

    def initModel(self):
        model_cls = getattr(cls.net, self.cli_args.model)
        model = model_cls()

        if self.cli_args.finetune:
            d = torch.load(self.cli_args.finetune, map_location="cpu")
            # Names of the top-level blocks that have trainable parameters.
            model_blocks = [
                n
                for n, subm in model.named_children()
                if len(list(subm.parameters())) > 0
            ]
            # The last finetune_depth blocks, counted from the head, stay trainable.
            finetune_blocks = model_blocks[-self.cli_args.finetune_depth :]
            log.info(
                f"finetuning from {self.cli_args.finetune}, blocks {' '.join(finetune_blocks)}"
            )
            # Load the pretrained weights, but skip the final block so the new
            # head starts from a fresh initialization.
            model.load_state_dict(
                {
                    k: v
                    for k, v in d["model_state"].items()
                    if k.split(".")[0] != model_blocks[-1]
                },
                strict=False,
            )
            # Freeze everything outside the finetuned blocks.
            for n, p in model.named_parameters():
                if n.split(".")[0] not in finetune_blocks:
                    p.requires_grad_(False)
        if self.use_cuda:
            log.info("Using CUDA; {} devices.".format(torch.cuda.device_count()))
            if torch.cuda.device_count() > 1:
                model = nn.DataParallel(model)
            model = model.to(self.device)
        return model

    def initOptimizer(self):
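        # A larger learning rate is presumably tolerable when finetuning,
        # since only the last block(s) receive gradients.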
        lr = 0.003 if self.cli_args.finetune else 0.001
        return SGD(self.model.parameters(), lr=lr, weight_decay=1e-4)
        # return Adam(self.model.parameters(), lr=3e-4)

    def initTrainDl(self):
        ds_cls = getattr(cls.dataset.luna, self.cli_args.dataset)

        train_ds = ds_cls(
            val_stride=10,
            isValSet_bool=False,
            ratio_int=1,
        )

        batch_size = self.cli_args.batch_size
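        # nn.DataParallel splits each batch across the visible GPUs, so scale
        # the batch size to keep the per-GPU batch size constant.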
        if self.use_cuda:
            batch_size *= torch.cuda.device_count()

        train_dl = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )

        return train_dl

    def initValDl(self):
        ds_cls = getattr(cls.dataset.luna, self.cli_args.dataset)

        val_ds = ds_cls(
            val_stride=10,
            isValSet_bool=True,
        )

        batch_size = self.cli_args.batch_size
        if self.use_cuda:
            batch_size *= torch.cuda.device_count()

        val_dl = DataLoader(
            val_ds,
            batch_size=batch_size,
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )

        return val_dl

    def initTensorboardWriters(self):
        if self.trn_writer is None:
            log_dir = os.path.join("runs", "cls", self.time_str)

            self.trn_writer = SummaryWriter(
                log_dir=log_dir + "-trn_cls-" + self.cli_args.comment
            )
            self.val_writer = SummaryWriter(
                log_dir=log_dir + "-val_cls-" + self.cli_args.comment
            )

    def main(self):
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        train_dl = self.initTrainDl()
        val_dl = self.initValDl()

        best_score = 0.0
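        # Validate every fifth epoch in a full run, but every epoch when
        # finetuning (presumably because finetuning runs are short).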
        validation_cadence = 5 if not self.cli_args.finetune else 1
        for epoch_ndx in range(1, self.cli_args.epochs + 1):
            log.info(
                "Epoch {} of {}, {}/{} batches of size {}*{}".format(
                    epoch_ndx,
                    self.cli_args.epochs,
                    len(train_dl),
                    len(val_dl),
                    self.cli_args.batch_size,
                    (torch.cuda.device_count() if self.use_cuda else 1),
                )
            )

            trnMetrics_t = self.doTraining(epoch_ndx, train_dl)
            self.logMetrics(epoch_ndx, "trn", trnMetrics_t)

            if epoch_ndx == 1 or epoch_ndx % validation_cadence == 0:
                valMetrics_t = self.doValidation(epoch_ndx, val_dl)
                score = self.logMetrics(epoch_ndx, "val", valMetrics_t)
                best_score = max(score, best_score)

                # TODO: this 'cls' will need to change for the malignant classifier
                self.saveModel("cls", epoch_ndx, score == best_score)

        if self.trn_writer is not None:
            self.trn_writer.close()
            self.val_writer.close()

    def doTraining(self, epoch_ndx, train_dl):
        self.model.train()
        train_dl.dataset.shuffleSamples()
        trnMetrics_g = torch.zeros(
            METRICS_SIZE,
            len(train_dl.dataset),
            device=self.device,
        )

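        # enumerateWithEstimate wraps the loader with progress logging and an
        # ETA; start_ndx presumably skips the first batches so worker startup
        # cost does not skew the estimate.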
        batch_iter = enumerateWithEstimate(
            train_dl,
            "E{} Training".format(epoch_ndx),
            start_ndx=train_dl.num_workers,
        )
        for batch_ndx, batch_tup in batch_iter:
            self.optimizer.zero_grad()

            loss_var = self.computeBatchLoss(
                batch_ndx, batch_tup, train_dl.batch_size, trnMetrics_g, augment=True
            )

            loss_var.backward()
            self.optimizer.step()

        self.totalTrainingSamples_count += len(train_dl.dataset)

        return trnMetrics_g.to("cpu")

    def doValidation(self, epoch_ndx, val_dl):
        with torch.no_grad():
            self.model.eval()
            valMetrics_g = torch.zeros(
                METRICS_SIZE,
                len(val_dl.dataset),
                device=self.device,
            )

            batch_iter = enumerateWithEstimate(
                val_dl,
                "E{} Validation ".format(epoch_ndx),
                start_ndx=val_dl.num_workers,
            )
            for batch_ndx, batch_tup in batch_iter:
                self.computeBatchLoss(
                    batch_ndx, batch_tup, val_dl.batch_size, valMetrics_g, augment=False
                )

        return valMetrics_g.to("cpu")

    def computeBatchLoss(
        self, batch_ndx, batch_tup, batch_size, metrics_g, augment=True
    ):
        input_t, label_t, index_t, _series_list, _center_list = batch_tup

        input_g = input_t.to(self.device, non_blocking=True)
        label_g = label_t.to(self.device, non_blocking=True)
        index_g = index_t.to(self.device, non_blocking=True)

        if augment:
            input_g = cls.aug.augment3d(input_g)

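        # The model is assumed to return both raw logits and softmax
        # probabilities: cross_entropy applies log-softmax internally, so the
        # loss takes the logits while the recorded metrics use probabilities.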
        logits_g, probability_g = self.model(input_g)

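        # label_g is assumed to be a one-hot (N, 2) long tensor, so column 1
        # holds the integer class index that cross_entropy expects;
        # reduction="none" keeps per-sample losses for the metrics below.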
        loss_g = nn.functional.cross_entropy(logits_g, label_g[:, 1], reduction="none")
        start_ndx = batch_ndx * batch_size
        end_ndx = start_ndx + label_t.size(0)
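        # Mapping batch position to columns of metrics_g relies on the
        # DataLoader iterating in order (shuffling happens inside the dataset
        # via shuffleSamples, not in the loader).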

        _, predLabel_g = torch.max(probability_g, dim=1)

        # log.debug(index_g)

        metrics_g[METRICS_LABEL_NDX, start_ndx:end_ndx] = index_g
        metrics_g[METRICS_PRED_NDX, start_ndx:end_ndx] = predLabel_g
        # metrics_g[METRICS_PRED_N_NDX, start_ndx:end_ndx] = probability_g[:,0]
        metrics_g[METRICS_PRED_P_NDX, start_ndx:end_ndx] = probability_g[:, 1]
        # metrics_g[METRICS_PRED_M_NDX, start_ndx:end_ndx] = probability_g[:,2]
        metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_g

        return loss_g.mean()

    def logMetrics(
        self,
        epoch_ndx,
        mode_str,
        metrics_t,
        classificationThreshold=0.5,
    ):
        self.initTensorboardWriters()
        log.info(
            "E{} {}".format(
                epoch_ndx,
                type(self).__name__,
            )
        )

        if self.cli_args.dataset == "MalignantLunaDataset":
            pos = "mal"
            neg = "ben"
        else:
            pos = "pos"
            neg = "neg"

        negLabel_mask = metrics_t[METRICS_LABEL_NDX] == 0
        negPred_mask = metrics_t[METRICS_PRED_NDX] == 0

        posLabel_mask = ~negLabel_mask
        posPred_mask = ~negPred_mask

        # benLabel_mask = metrics_t[METRICS_LABEL_NDX] == 1
        # benPred_mask = metrics_t[METRICS_PRED_NDX] == 1
        #
        # malLabel_mask = metrics_t[METRICS_LABEL_NDX] == 2
        # malPred_mask = metrics_t[METRICS_PRED_NDX] == 2

        # benLabel_mask = ~malLabel_mask & posLabel_mask
        # benPred_mask = ~malPred_mask & posLabel_mask

        neg_count = int(negLabel_mask.sum())
        pos_count = int(posLabel_mask.sum())
        # ben_count = int(benLabel_mask.sum())
        # mal_count = int(malLabel_mask.sum())

        neg_correct = int((negLabel_mask & negPred_mask).sum())
        pos_correct = int((posLabel_mask & posPred_mask).sum())
        # ben_correct = int((benLabel_mask & benPred_mask).sum())
        # mal_correct = int((malLabel_mask & malPred_mask).sum())

        trueNeg_count = neg_correct
        truePos_count = pos_correct

        falsePos_count = neg_count - neg_correct
        falseNeg_count = pos_count - pos_correct

        metrics_dict = {}
        metrics_dict["loss/all"] = metrics_t[METRICS_LOSS_NDX].mean()
        metrics_dict["loss/neg"] = metrics_t[METRICS_LOSS_NDX, negLabel_mask].mean()
        metrics_dict["loss/pos"] = metrics_t[METRICS_LOSS_NDX, posLabel_mask].mean()
        # metrics_dict['loss/ben'] = metrics_t[METRICS_LOSS_NDX, benLabel_mask].mean()
        # metrics_dict['loss/mal'] = metrics_t[METRICS_LOSS_NDX, malLabel_mask].mean()

        metrics_dict["correct/all"] = (
            (pos_correct + neg_correct) / metrics_t.shape[1] * 100
        )
        metrics_dict["correct/neg"] = (neg_correct) / neg_count * 100
        metrics_dict["correct/pos"] = (pos_correct) / pos_count * 100
        # metrics_dict['correct/ben'] = (ben_correct) / ben_count * 100
        # metrics_dict['correct/mal'] = (mal_correct) / mal_count * 100

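        # Dividing by np.float64 produces inf/nan instead of raising
        # ZeroDivisionError when a class has no samples or predictions.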
        precision = metrics_dict["pr/precision"] = truePos_count / np.float64(
            truePos_count + falsePos_count
        )
        recall = metrics_dict["pr/recall"] = truePos_count / np.float64(
            truePos_count + falseNeg_count
        )

        metrics_dict["pr/f1_score"] = 2 * (precision * recall) / (precision + recall)

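        # ROC sweep: broadcast the positive-class probabilities against 100
        # thresholds from 1 down to 0. For each threshold, tpr is the fraction
        # of positives scoring at or above it and fpr the corresponding
        # fraction of negatives; AUC is the trapezoidal integral of tpr over
        # fpr.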
        # https://pytorch.org/docs/stable/generated/torch.linspace.html
        # Since PyTorch 1.11, linspace requires an explicit steps argument;
        # steps=100 matches the old default.
        threshold = torch.linspace(1, 0, steps=100)
        tpr = (
            metrics_t[None, METRICS_PRED_P_NDX, posLabel_mask] >= threshold[:, None]
        ).sum(1).float() / pos_count
        fpr = (
            metrics_t[None, METRICS_PRED_P_NDX, negLabel_mask] >= threshold[:, None]
        ).sum(1).float() / neg_count
        fp_diff = fpr[1:] - fpr[:-1]
        tp_avg = (tpr[1:] + tpr[:-1]) / 2
        auc = (fp_diff * tp_avg).sum()
        metrics_dict["auc"] = auc

        log.info(
            (
                "E{} {:8} {loss/all:.4f} loss, "
                + "{correct/all:-5.1f}% correct, "
                + "{pr/precision:.4f} precision, "
                + "{pr/recall:.4f} recall, "
                + "{pr/f1_score:.4f} f1 score, "
                + "{auc:.4f} auc"
            ).format(
                epoch_ndx,
                mode_str,
                **metrics_dict,
            )
        )
        log.info(
            (
                "E{} {:8} {loss/neg:.4f} loss, "
                + "{correct/neg:-5.1f}% correct ({neg_correct:} of {neg_count:})"
            ).format(
                epoch_ndx,
                mode_str + "_" + neg,
                neg_correct=neg_correct,
                neg_count=neg_count,
                **metrics_dict,
            )
        )
        log.info(
            (
                "E{} {:8} {loss/pos:.4f} loss, "
                + "{correct/pos:-5.1f}% correct ({pos_correct:} of {pos_count:})"
            ).format(
                epoch_ndx,
                mode_str + "_" + pos,
                pos_correct=pos_correct,
                pos_count=pos_count,
                **metrics_dict,
            )
        )
        writer = getattr(self, mode_str + "_writer")

        for key, value in metrics_dict.items():
            key = key.replace("pos", pos)
            key = key.replace("neg", neg)
            writer.add_scalar(key, value, self.totalTrainingSamples_count)

        fig = pyplot.figure()
        pyplot.plot(fpr, tpr)
        writer.add_figure("roc", fig, self.totalTrainingSamples_count)

        writer.add_scalar("auc", auc, self.totalTrainingSamples_count)

        bins = np.linspace(0, 1)

        writer.add_histogram(
            "label_neg",
            metrics_t[METRICS_PRED_P_NDX, negLabel_mask],
            self.totalTrainingSamples_count,
            bins=bins,
        )
        writer.add_histogram(
            "label_pos",
            metrics_t[METRICS_PRED_P_NDX, posLabel_mask],
            self.totalTrainingSamples_count,
            bins=bins,
        )

        if not self.cli_args.malignant:
            score = metrics_dict["pr/f1_score"]
        else:
            score = metrics_dict["auc"]

        return score

    def saveModel(self, type_str, epoch_ndx, isBest=False):
        file_path = os.path.join(
            DATA_PATH,
            "models",
            type_str,
            "{}_{}_{}.{}.state".format(
                type_str,
                self.time_str,
                self.cli_args.comment,
                self.totalTrainingSamples_count,
            ),
        )

        os.makedirs(os.path.dirname(file_path), mode=0o755, exist_ok=True)

        model = self.model
        if isinstance(model, torch.nn.DataParallel):
            model = model.module

        state = {
            "model_state": model.state_dict(),
            "model_name": type(model).__name__,
            "optimizer_state": self.optimizer.state_dict(),
            "optimizer_name": type(self.optimizer).__name__,
            "epoch": epoch_ndx,
            "totalTrainingSamples_count": self.totalTrainingSamples_count,
        }
        torch.save(state, file_path)

        log.debug("Saved model params to {}".format(file_path))

        if isBest:
            best_path = os.path.join(
                DATA_PATH,
                "models",
                type_str,
                "{}_{}_{}.{}.state".format(
                    type_str,
                    self.time_str,
                    self.cli_args.comment,
                    "best",
                ),
            )
            shutil.copyfile(file_path, best_path)

            log.debug("Saved model params to {}".format(best_path))

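        # Log a checksum so the exact checkpoint file can be identified later.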
        with open(file_path, "rb") as f:
            log.info("SHA1: " + hashlib.sha1(f.read()).hexdigest())


def parse_args():
    parser = argparse.ArgumentParser()

    # fmt: off
    parser.add_argument("--batch-size", default=24, type=int, help="Batch size to use for training")
    parser.add_argument("--num-workers", default=8, type=int, help="Number of worker processes for background data loading")
    parser.add_argument("--epochs", default=1, type=int, help="Number of epochs to train for")
    parser.add_argument("--dataset", action="store", default="LunaDataset", help="What to dataset to feed the model.")
    parser.add_argument("--model", action="store", default="LunaModel", help="What to model class name to use.")
    parser.add_argument("--malignant", action="store_true", default=False, help="Train the model to classify nodules as benign or malignant.")
    parser.add_argument("--finetune", default="", help="Start finetuning from this model.")
    parser.add_argument("--finetune-depth", type=int, default=1, help="Number of blocks (counted from the head) to include in finetuning")
    parser.add_argument("comment", nargs="?", default="dlwpt", help="Comment suffix for Tensorboard run.")
    # fmt: on

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    ClassificationTrainingApp(args).main()
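
# Example invocations, assuming this file is saved as training.py (the final
# positional argument is the Tensorboard comment suffix; the checkpoint path
# is a placeholder):
#
#   python training.py --epochs 10 --batch-size 24 nodule-run
#   python training.py --malignant --dataset MalignantLunaDataset \
#       --finetune <path-to-cls-checkpoint>.state --finetune-depth 2 mal-run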
